{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import sys\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "import os\n",
    "\n",
    "# Put the aisuite source tree (path relative to the notebook's working directory) on the\n",
    "# import path. Guarded so re-running this cell does not append the same entry again.\n",
    "AISUITE_PATH = '../../aisuite'\n",
    "if AISUITE_PATH not in sys.path:\n",
    "    sys.path.append(AISUITE_PATH)\n",
    "\n",
    "# Load from .env file if available\n",
    "load_dotenv(find_dotenv())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make a request to the model without tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aisuite import Client\n",
    "\n",
    "client = Client()\n",
    "# Azure needs an explicit endpoint; the other providers read their parameters from environment variables.\n",
    "# Configure Azure only when a key is set, so the notebook still runs for users without an Azure account.\n",
    "if os.environ.get(\"AZURE_API_KEY\"):\n",
    "    client.configure({\"azure\": {\n",
    "        \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
    "        \"base_url\": \"https://aisuite-mistral-large-2407.westus3.models.ai.azure.com/v1/\",\n",
    "    }})\n",
    "\n",
    "# Keep exactly one `model` line active; entries ending in \":\" still need a model name for that provider.\n",
    "# model = \"anthropic:claude-3-5-sonnet-20241022\"\n",
    "# model = \"aws:mistral.mistral-7b-instruct-v0:2\"\n",
    "# model = \"azure:aisuite-mistral-large\"\n",
    "# model = \"cohere:command-r-plus\"\n",
    "# model = \"deepseek:deepseek-chat\"\n",
    "# model = \"fireworks:accounts/fireworks/models/llama-v3p1-405b-instruct\"\n",
    "# model = \"google:gemini-1.5-pro-002\"\n",
    "# model = \"groq:llama-3.3-70b-versatile\"\n",
    "# model = \"huggingface:meta-llama/Llama-3.1-8B-Instruct\"\n",
    "# model = \"mistral:mistral-large-latest\"\n",
    "# model = \"nebius:\"\n",
    "# model = \"ollama:\"\n",
    "# model = \"sambanova:Meta-Llama-3.3-70B-Instruct\"\n",
    "# model = \"together:meta-llama/Llama-3.3-70B-Instruct-Turbo\"\n",
    "# model = \"watsonx:\"\n",
    "model = \"xai:grok-2-latest\"\n",
    "\n",
    "messages = [{\n",
    "    \"role\": \"user\",\n",
    "    \"content\": \"What is the current temperature in San Francisco in Celsius?\"}]\n",
    "\n",
    "response = client.chat.completions.create(\n",
    "    model=model, messages=messages)\n",
    "\n",
    "print(f\"For model: {model}\")\n",
    "print(response.choices[0].message.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Equip the model with tools"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the mock tool functions and their schemas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mock tool functions.\n",
    "def get_current_temperature(location: str, unit: str):\n",
    "    \"\"\"Return a simulated current temperature for `location`, expressed in `unit`.\n",
    "\n",
    "    The mock always reads 72 F. When Celsius is requested the value is converted (and\n",
    "    rounded), so the model is never handed a Fahrenheit number for a Celsius question.\n",
    "    The location and the unit actually used are echoed back, like get_rain_probability does.\n",
    "    \"\"\"\n",
    "    fahrenheit = 72  # Simulate fetching temperature from an API\n",
    "    if str(unit).strip().lower() == \"celsius\":\n",
    "        celsius = round((fahrenheit - 32) * 5 / 9)\n",
    "        return {\"location\": location, \"temperature\": celsius, \"unit\": \"Celsius\"}\n",
    "    return {\"location\": location, \"temperature\": fahrenheit, \"unit\": \"Fahrenheit\"}\n",
    "\n",
    "def get_rain_probability(location: str):\n",
    "    \"\"\"Return a fixed, simulated rain probability (40) for `location`.\"\"\"\n",
    "    chance = 40  # stand-in for a real forecast lookup\n",
    "    return dict(location=location, probability=chance)\n",
    "\n",
    "# Function to get the available tools (functions) to provide to the model\n",
    "# Note: we could use decorators or utils from OpenAI to generate this.\n",
    "def get_available_tools():\n",
    "    \"\"\"Return the OpenAI-style tool specs describing the mock functions above.\n",
    "\n",
    "    Each entry wraps a JSON-schema description of one function. A fresh list (with\n",
    "    fresh nested dicts) is built on every call, so callers may mutate it freely.\n",
    "    \"\"\"\n",
    "    def location_param():\n",
    "        return {\n",
    "            \"type\": \"string\",\n",
    "            \"description\": \"The city and state, e.g., San Francisco, CA\",\n",
    "        }\n",
    "\n",
    "    def function_tool(name, description, properties, required):\n",
    "        return {\n",
    "            \"type\": \"function\",\n",
    "            \"function\": {\n",
    "                \"name\": name,\n",
    "                \"description\": description,\n",
    "                \"parameters\": {\n",
    "                    \"type\": \"object\",\n",
    "                    \"properties\": properties,\n",
    "                    \"required\": required,\n",
    "                },\n",
    "            },\n",
    "        }\n",
    "\n",
    "    temperature_tool = function_tool(\n",
    "        \"get_current_temperature\",\n",
    "        \"Get the current temperature for a specific location\",\n",
    "        {\n",
    "            \"location\": location_param(),\n",
    "            \"unit\": {\n",
    "                \"type\": \"string\",\n",
    "                \"enum\": [\"Celsius\", \"Fahrenheit\"],\n",
    "                \"description\": \"The temperature unit to use.\",\n",
    "            },\n",
    "        },\n",
    "        [\"location\", \"unit\"],\n",
    "    )\n",
    "    rain_tool = function_tool(\n",
    "        \"get_rain_probability\",\n",
    "        \"Get the probability of rain for a specific location\",\n",
    "        {\"location\": location_param()},\n",
    "        [\"location\"],\n",
    "    )\n",
    "    return [temperature_tool, rain_tool]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to process tool calls and get the result\n",
    "def handle_tool_call(tool_call):\n",
    "    \"\"\"Run the local function the model asked for and return its result.\n",
    "\n",
    "    `tool_call` is one entry of `message.tool_calls`; it must expose `function.name`\n",
    "    and `function.arguments` (a JSON-encoded string of keyword arguments).\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if the model requested a tool that is not implemented here.\n",
    "        json.JSONDecodeError: if the arguments string is not valid JSON.\n",
    "    \"\"\"\n",
    "    # Map function names to actual tool function implementations\n",
    "    tools_map = {\n",
    "        \"get_current_temperature\": get_current_temperature,\n",
    "        \"get_rain_probability\": get_rain_probability,\n",
    "    }\n",
    "    function_name = tool_call.function.name\n",
    "    if function_name not in tools_map:\n",
    "        raise KeyError(f\"Model requested unknown tool {function_name!r}; available: {sorted(tools_map)}\")\n",
    "\n",
    "    # Be lenient about the argument encoding: accept an already-decoded dict, and treat\n",
    "    # an empty or missing arguments string as \"no arguments\".\n",
    "    raw_arguments = tool_call.function.arguments\n",
    "    if isinstance(raw_arguments, dict):\n",
    "        arguments = raw_arguments\n",
    "    else:\n",
    "        arguments = json.loads(raw_arguments or \"{}\")\n",
    "    return tools_map[function_name](**arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to format tool response as a message\n",
    "def create_tool_response_message(tool_call, tool_result):\n",
    "    \"\"\"Wrap `tool_result` as a `role=\"tool\"` chat message answering `tool_call`.\n",
    "\n",
    "    `tool_call_id` ties this reply to the request it answers; the result is\n",
    "    JSON-encoded so it can travel as the message's text content.\n",
    "    \"\"\"\n",
    "    message = dict(role=\"tool\")\n",
    "    message[\"tool_call_id\"] = tool_call.id\n",
    "    message[\"name\"] = tool_call.function.name\n",
    "    message[\"content\"] = json.dumps(tool_result)\n",
    "    return message"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Call the model with tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuses `client` and `model` from the earlier no-tools request cell; change `model` there\n",
    "# to try tool calling with a different provider.\n",
    "messages = [{\n",
    "    \"role\": \"user\",\n",
    "    \"content\": \"What is the current temperature in San Francisco in Celsius?\"}]\n",
    "\n",
    "tools = get_available_tools()\n",
    "\n",
    "# Make the initial request, offering the tools to the model\n",
    "response = client.chat.completions.create(\n",
    "    model=model, messages=messages, tools=tools)\n",
    "\n",
    "print(response.choices[0].message)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Process tool calls: parse each call's name and arguments, run the matching function, and send the results back to the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work on a copy of the conversation so re-running this cell does not pile up duplicate turns in `messages`.\n",
    "assistant_message = response.choices[0].message\n",
    "\n",
    "if assistant_message.tool_calls:\n",
    "    # The assistant turn that requested the tools goes in exactly once, followed by one\n",
    "    # \"tool\" message per tool call. OpenAI-style chat APIs expect every tool_call_id to be\n",
    "    # answered before the conversation is sent back, so call the model only after all results are in.\n",
    "    followup_messages = [*messages, assistant_message]  # Model's function call message\n",
    "    for tool_call in assistant_message.tool_calls:\n",
    "        tool_result = handle_tool_call(tool_call)\n",
    "        print(tool_result)\n",
    "        followup_messages.append(create_tool_response_message(tool_call, tool_result))\n",
    "\n",
    "    # Send all tool responses back to the model in a single follow-up request\n",
    "    final_response = client.chat.completions.create(\n",
    "        model=model, messages=followup_messages, tools=tools)\n",
    "    print(final_response.choices[0].message)\n",
    "\n",
    "    # Output the final response from the model\n",
    "    print(final_response.choices[0].message.content)\n",
    "else:\n",
    "    # No tools requested: the first response already holds the model's answer.\n",
    "    print(assistant_message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
